import gfootball.env as football_env
from gfootball.env import observation_preprocessing
import gym
import numpy as np
from env.multiagent.environment import MultiAgentEnv
import env.multiagent.scenarios as scenarios
import matplotlib.pyplot as plt


class MpeEnv(object):
    """Adapter exposing a scenario-based ``MultiAgentEnv`` through the
    interface used by the rest of this codebase.

    The wrapped environment is built from the scenario named by ``args.env``.
    Every observation handed out by this class is divided by ``self.scale``.
    """

    def __init__(self, args, dense_reward, dump_freq, render=False):
        """Build the scenario, its world and the wrapped environment.

        Args:
            args: Settings object; ``args.env`` (scenario name) and
                ``args.seed`` (scenario seed) are read here.
            dense_reward: Stored as ``self.dense_reward``; not otherwise
                used inside this class.
            dump_freq: Accepted for signature compatibility; unused here.
            render: Accepted for signature compatibility; unused here.
        """
        # Load the scenario from its script, create the world, then seed it.
        scenario = scenarios.load(args.env).Scenario()
        world = scenario.make_world()
        scenario.seed(args.seed)

        # Create the multi-agent environment from the scenario callbacks.
        self.env = MultiAgentEnv(
            world,
            scenario.reset_world,
            scenario.reward,
            scenario.observation,
            info_callback=None,
            shared_viewer=False,
        )

        self.n_agents = self.env.n
        self.time_limit = 40  # reported as 'episode_limit' by get_env_info()
        self.n_actions = 5  # width of the one-hot action vector built in step()
        self.time_step = 0
        # Per-agent observation length, read from agent 0's observation space.
        self.obs_dim = self.env.observation_space[0].shape[0]
        # The global state is all agents' observations flattened together.
        self.state_shape = self.n_agents * self.obs_dim
        self.n_enemy = 2
        # Number of leading observation features exposed as the "p_state".
        self.p_state_dim = 6
        # Defined up front so get_p_state() is safe to call before reset().
        self.p_state = None
        self.dense_reward = dense_reward  # select whether to use dense reward
        self.get_ball_rew = False
        self.scale = 1000  # divisor applied to every raw observation

    def _scaled_obs(self):
        """Return the wrapped env's current observations as an array / scale."""
        return np.array(self.env.get_obs()) / self.scale

    def get_p_state(self):
        """Return the p_state cached by the last ``reset()``/``get_obs()``.

        The value is ``None`` until one of those methods has been called.
        """
        return self.p_state

    def _encode_ball_which_zone(self, ball_x, ball_y):
        """One-hot encode which of six pitch zones contains ``(ball_x, ball_y)``.

        Zones are tested in order and the first match wins; index 5 is the
        fallback for positions that match no other zone.

        Returns:
            A list of six numbers with ``1.0`` at the matching zone index.
        """
        MIDDLE_X, PENALTY_X, END_X = 0.2, 0.64, 1.0
        PENALTY_Y, END_Y = 0.27, 0.42

        in_penalty_y = -PENALTY_Y < ball_y < PENALTY_Y
        in_field_y = -END_Y < ball_y < END_Y

        if -END_X <= ball_x < -PENALTY_X and in_penalty_y:
            zone = 0
        elif -END_X <= ball_x < -MIDDLE_X and in_field_y:
            zone = 1
        elif -MIDDLE_X <= ball_x <= MIDDLE_X and in_field_y:
            zone = 2
        elif PENALTY_X < ball_x <= END_X and in_penalty_y:
            zone = 3
        elif MIDDLE_X < ball_x <= END_X and in_field_y:
            zone = 4
        else:
            zone = 5

        encoding = [0] * 6
        encoding[zone] = 1.0
        return encoding

    def reset(self):
        """Reset the wrapped environment and the per-episode bookkeeping.

        Returns:
            A tuple ``(obs, state)``: the scaled observations returned by the
            wrapped env's ``reset()`` and the flattened global state.
        """
        self.time_step = 0
        self.get_ball_rew = False

        obs = np.array(self.env.reset()) / self.scale
        self.p_state = obs[..., :self.p_state_dim]
        return obs, self.get_global_state()

    def check_if_done(self):
        """Return True if the ball, or any of the last ``n_agents`` left-team
        players, has a negative x coordinate.

        NOTE(review): this reads ``self.env.unwrapped.observation()[0]`` as a
        dict with 'ball' and 'left_team' entries, which looks like a football
        style observation -- confirm the wrapped ``MultiAgentEnv`` actually
        provides it before relying on this method.
        """
        cur_obs = self.env.unwrapped.observation()[0]
        ball_loc = cur_obs['ball']
        ours_loc = cur_obs['left_team'][-self.n_agents:]

        if ball_loc[0] < 0 or any(ours_loc[:, 0] < 0):
            return True

        return False

    def step(self, actions):
        """Advance the wrapped environment by one step.

        Args:
            actions: One integer action index per agent, in agent order.

        Returns:
            A tuple ``(reward, done, info)``: the mean of the per-agent
            rewards, the done flag of the first agent, and the info object
            returned by the wrapped environment.
        """
        # The wrapped env receives one integer one-hot row per agent.
        action_onehot = np.zeros((self.n_agents, self.n_actions), dtype=int)
        action_onehot[np.arange(self.n_agents), actions] = 1
        self.time_step += 1
        # Fresh observations are not returned here; callers use get_obs().
        _, rew_n, done_n, info_ns = self.env.step(action_onehot)
        return np.array(rew_n).mean(), done_n[0], info_ns

    def seed(self, seed):
        """Forward ``seed`` to the wrapped environment."""
        self.env.seed(seed)

    def close(self):
        """Close the wrapped environment."""
        self.env.close()

    def get_global_state(self):
        """Return all agents' scaled observations flattened into one vector."""
        return self._scaled_obs().reshape(-1)

    def get_state(self):
        """Alias of ``get_global_state()``."""
        return self.get_global_state()

    def get_obs(self):
        """Return all agents' scaled observations and refresh the p_state.

        Side effect: ``self.p_state`` is set to the first ``p_state_dim``
        features along the last axis of the returned array.
        """
        obs = self._scaled_obs()
        self.p_state = obs[..., :self.p_state_dim]
        return obs

    def get_obs_size(self):
        """Return the per-agent observation length."""
        return self.obs_dim

    def get_state_size(self):
        """Return the length of the flattened global state."""
        return self.state_shape

    def get_avail_agent_actions(self, id):
        """Return the availability mask for agent ``id`` (always all ones).

        The argument is ignored: every action is reported as available.
        """
        return np.ones((self.n_actions,))

    def get_env_info(self):
        """Return the environment dimensions as a plain dict."""
        return {
            'n_actions': self.n_actions,
            'obs_shape': self.obs_dim,
            'n_agents': self.n_agents,
            'state_shape': self.state_shape,
            'episode_limit': self.time_limit,
            'n_enemy': self.n_enemy,
            # Report the configured width rather than a second hard-coded 6.
            'p_state': self.p_state_dim,
        }


# def make_football_env(seed_dir, dump_freq=1000, representation='extracted', render=False):
def make_mpe_env(args, dense_reward=False, dump_freq=1000, render=False):
    """Construct and return an ``MpeEnv``.

    All four arguments are forwarded unchanged to the ``MpeEnv``
    constructor. The returned object is driven through its ``reset()``
    and ``step()`` methods.
    """
    env = MpeEnv(
        args,
        dense_reward=dense_reward,
        dump_freq=dump_freq,
        render=render,
    )
    return env
